package com.ideaaedi.commonspring.aop;

import com.alibaba.fastjson2.JSON;
import com.ideaaedi.commonspring.filter.JsonIgnoreValueFilter;
import com.ideaaedi.commonspring.lite.params.*;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.Ordered;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Method interceptor that records the input parameters and the return value of intercepted methods.
 * <p>
 * Note: when both the {@link RecordParameters} annotation and the pointcut expression match a method, the settings
 * of the annotation take precedence over the expression defaults.
 * <p>
 * <b>Important:</b> this advice writes the log lines on behalf of the intercepted code through its own logger
 * (the logger of {@code ParameterRecorderAdvice}). That logger therefore has to be enabled for the lowest level used
 * anywhere (usually {@code debug}), e.g. by configuring
 * <code>logging.level.com.ideaaedi.commonspring.aop.ParameterRecorderAdvice=debug</code>, so that it is allowed to
 * emit every level that is requested.
 *
 * @author <font size = "20" color = "#3CAA3C"><a href="https://gitee.com/JustryDeng">JustryDeng</a></font> <img
 * src="https://gitee.com/JustryDeng/shared-files/raw/master/JustryDeng/avatar.jpg" />
 * @since 2022/6/21 11:38
 */
public class ParameterRecorderAdvice implements MethodInterceptor, InitializingBean, Ordered {

    public static final String BEAN_NAME = "parameterRecorderAdvice";

    private static final Logger log = LoggerFactory.getLogger(ParameterRecorderAdvice.class);

    /**
     * Name of the {@code void} type; used to detect methods without a return value.
     */
    private static final String VOID_STR = void.class.getName();

    /**
     * Type on which the logging methods are resolved reflectively. The public SLF4J {@link Logger} interface is used
     * (rather than the concrete implementation class) so that {@link Method#invoke} never depends on the
     * accessibility of the logging backend's implementation class.
     */
    private static final Class<?> LOG_CLASS = Logger.class;

    /**
     * Log-level name (e.g. {@code "DEBUG"}) -> corresponding {@code Logger#xxx(String, Object)} method.
     * Populated by {@link #afterPropertiesSet()}.
     */
    private static final Map<String, Method> METHOD_MAP = new ConcurrentHashMap<>(8);

    private ParameterSerializer parameterSerializer;

    private RequestPathProvider requestPathProvider;

    private ParameterNameDiscoverer parameterNameDiscoverer;

    private RequestEntranceJudger requestEntranceJudger;

    /**
     * methodReference prefixes to include. <br/> Empty means every method is included. <br/> A methodReference
     * looks like: com.ideaaedi.commonspring.aop.ParameterRecorderAdvice#init
     */
    private final Set<String> includePrefixes;

    /**
     * methodReference prefixes to exclude. <br/> Empty means nothing is excluded. <br/> A methodReference looks
     * like: com.ideaaedi.commonspring.aop.ParameterRecorderAdvice#init
     */
    private final Set<String> excludePrefixes;

    /**
     * How parameter values are converted to strings.
     */
    private final ParameterHandleModeEnum handleMode;

    /**
     * Separator inserted between log segments: a line break when pretty printing is enabled, a space otherwise.
     */
    private final String prettyPlaceholder;

    /**
     * Parameter/return value types whose values are not serialized; a placeholder is logged instead.
     */
    @SuppressWarnings("rawtypes")
    private final Set<Class> ignoreParamTypesSet;

    /**
     * JSON filter that skips values of {@link #ignoreParamTypesSet}. Created lazily and kept per instance, because
     * it depends on this instance's ignore types (a shared static filter would leak one instance's configuration
     * into another). A racy double creation is harmless: every created filter is equivalent.
     */
    private volatile JsonIgnoreValueFilter jsonIgnoreValueFilter;

    /**
     * Customizer supplying the pluggable helpers and the advice order.
     */
    private final ParameterRecorderCustomizer parameterRecorderCustomizer;

    /**
     * @param includePrefixes methodReference prefixes to include; {@code null} means include everything
     * @param excludePrefixes methodReference prefixes to exclude; {@code null} means exclude nothing
     * @param handleMode how parameter values are converted to strings
     * @param pretty whether to put each log segment on its own line
     * @param ignoreParamTypes types whose values should not be serialized; {@code null} means none
     * @param parameterRecorderCustomizer customizer for the helpers; must not be {@code null}
     * @throws NullPointerException if {@code parameterRecorderCustomizer} is {@code null}
     */
    public ParameterRecorderAdvice(Set<String> includePrefixes, Set<String> excludePrefixes,
                                   ParameterHandleModeEnum handleMode,
                                   boolean pretty, @SuppressWarnings("rawtypes") Class[] ignoreParamTypes,
                                   ParameterRecorderCustomizer parameterRecorderCustomizer) {
        this.includePrefixes = includePrefixes == null ? new HashSet<>() : includePrefixes;
        this.excludePrefixes = excludePrefixes == null ? new HashSet<>() : excludePrefixes;
        this.handleMode = handleMode;
        this.prettyPlaceholder = pretty ? "\n" : " ";
        this.ignoreParamTypesSet = new HashSet<>();
        if (ignoreParamTypes != null) {
            ignoreParamTypesSet.addAll(Arrays.asList(ignoreParamTypes));
        }
        this.parameterRecorderCustomizer = Objects.requireNonNull(parameterRecorderCustomizer, "parameterRecorderCustomizer cannot be null.");
    }

    /**
     * Resolves the logging methods and obtains the customizable helpers from the customizer.
     *
     * @throws NoSuchMethodException if a logging method cannot be resolved
     * @throws NullPointerException if {@code handleMode} is {@code CUSTOM_SERIALIZER} but no serializer is provided
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        String debugStr = RecordParameters.LogLevel.DEBUG.name();
        String infoStr = RecordParameters.LogLevel.INFO.name();
        String warnStr = RecordParameters.LogLevel.WARN.name();
        // Logger method names are the lower-cased level names: debug / info / warn
        Method debugMethod = LOG_CLASS.getMethod(debugStr.toLowerCase(Locale.ROOT), String.class, Object.class);
        Method infoMethod = LOG_CLASS.getMethod(infoStr.toLowerCase(Locale.ROOT), String.class, Object.class);
        Method warnMethod = LOG_CLASS.getMethod(warnStr.toLowerCase(Locale.ROOT), String.class, Object.class);
        METHOD_MAP.put(debugStr, debugMethod);
        METHOD_MAP.put(infoStr, infoMethod);
        METHOD_MAP.put(warnStr, warnMethod);
        // customizable helpers
        parameterNameDiscoverer = parameterRecorderCustomizer.customParameterNameDiscoverer();
        requestPathProvider = parameterRecorderCustomizer.customRequestPathProvider();
        parameterSerializer = parameterRecorderCustomizer.customParameterSerializer();
        requestEntranceJudger = parameterRecorderCustomizer.customRequestEntranceJudger();
        if (handleMode == ParameterHandleModeEnum.CUSTOM_SERIALIZER) {
            Objects.requireNonNull(parameterSerializer, "parameterSerializer should not be null while "
                    + "parameterHandleMode is set to 'CUSTOM_SERIALIZER'.");
        }
    }

    /**
     * Around advice: logs the input parameters, proceeds, then logs the return value, according to the
     * include/exclude prefixes and the {@link RecordParameters} settings.
     *
     * @param methodInvocation the intercepted invocation
     * @return the value returned by the target method
     * @throws Throwable whatever the target method throws
     */
    @Nullable
    @Override
    public Object invoke(@Nonnull MethodInvocation methodInvocation) throws Throwable {
        Method targetMethod = methodInvocation.getMethod();
        Class<?> targetClazz = targetMethod.getDeclaringClass();
        String clazzName = targetClazz.getCanonicalName();
        if (clazzName == null) {
            // anonymous / local classes have no canonical name
            clazzName = targetClazz.getName();
        }
        // Documented form "fully.qualified.ClassName#methodName". Method#toString() must not be used here: it yields
        // "modifiers returnType Class.method(params)", which method-level prefixes could never match.
        String methodReference = clazzName + "#" + targetMethod.getName();
        boolean anyMatchExclude = excludePrefixes.stream().anyMatch(methodReference::startsWith);
        if (anyMatchExclude) {
            log.debug("anyMatchExclude is true, skip ParameterRecorderAdvice. methodReference is -> {}",
                    methodReference);
            return methodInvocation.proceed();
        }
        if (!CollectionUtils.isEmpty(includePrefixes)) {
            boolean anyMatchInclude = includePrefixes.stream().anyMatch(methodReference::startsWith);
            if (!anyMatchInclude) {
                log.debug("anyMatchInclude is false, skip ParameterRecorderAdvice. methodReference is -> {}",
                        methodReference);
                return methodInvocation.proceed();
            }
        }
        // method-level annotation wins over the class-level one
        RecordParameters annotation = targetMethod.getAnnotation(RecordParameters.class);
        if (annotation == null) {
            annotation = targetClazz.getAnnotation(RecordParameters.class);
        }
        boolean shouldRecordInputParams;
        boolean shouldRecordOutputParams;
        RecordParameters.LogLevel logLevel;
        if (annotation == null) {
            // Matched only via the pointcut expression: record both input and output at INFO.
            shouldRecordInputParams = true;
            shouldRecordOutputParams = true;
            logLevel = RecordParameters.LogLevel.INFO;
        } else {
            RecordParameters.Strategy strategy = annotation.strategy();
            shouldRecordInputParams = strategy == RecordParameters.Strategy.INPUT
                    || strategy == RecordParameters.Strategy.INPUT_OUTPUT;
            shouldRecordOutputParams = strategy == RecordParameters.Strategy.OUTPUT
                    || strategy == RecordParameters.Strategy.INPUT_OUTPUT;
            logLevel = annotation.logLevel();
        }
        final String classMethodInfo = "Class#Method -> " + clazzName + "#" + targetMethod.getName();

        // a missing judger (customizer returned null) simply means "not an entry method"
        boolean isEntryMethod = requestEntranceJudger != null
                && requestEntranceJudger.ifEntrance(targetClazz, targetMethod, methodInvocation.getArguments());
        if (shouldRecordInputParams) {
            preHandle(methodInvocation, logLevel, targetMethod, classMethodInfo, isEntryMethod);
        }
        Object obj = methodInvocation.proceed();
        if (shouldRecordOutputParams) {
            postHandle(logLevel, targetMethod, obj, classMethodInfo, isEntryMethod);
        }
        return obj;
    }

    @Override
    public int getOrder() {
        return parameterRecorderCustomizer.customAdviceOrder();
    }

    /**
     * Logs the input parameters of the target method.
     *
     * @param methodInvocation the intercepted invocation (source of the argument values)
     * @param logLevel log level to use
     * @param targetMethod target method
     * @param classMethodInfo "Class#Method" description of the target
     * @param isEntryMethod whether the method is a request entry method (then the request path is logged too)
     */
    private void preHandle(MethodInvocation methodInvocation, RecordParameters.LogLevel logLevel,
                           Method targetMethod, String classMethodInfo, boolean isEntryMethod) {
        StringBuilder sb = new StringBuilder(64);
        sb.append(prettyPlaceholder).append("[the way in]");
        if (isEntryMethod) {
            sb.append(" request-path[").append(getRequestPath(targetMethod)).append("] ");
        }
        sb.append(classMethodInfo);
        Object[] parameterValues = methodInvocation.getArguments();
        if (parameterValues.length > 0) {
            String[] parameterNames = parameterNameDiscoverer == null
                    ? null : parameterNameDiscoverer.getParameterNames(targetMethod);
            if (parameterNames == null) {
                log.warn("Cannot determine parameter names. curr method -> {}", targetMethod.toGenericString());
                parameterNames = new String[0];
            }
            sb.append(", with parameters ");
            for (int i = 0; i < parameterValues.length; i++) {
                // fall back to a synthetic name if the discoverer returned fewer names than there are arguments
                String parameterName = i < parameterNames.length && parameterNames[i] != null
                        ? parameterNames[i] : "UnknownParamName" + i;
                String prettyStr = routeIgnoreOrPretty(parameterValues[i]);
                sb.append(prettyPlaceholder).append("\t").append(parameterName).append(" => ").append(prettyStr);
            }
            sb.append(prettyPlaceholder);
        } else {
            sb.append(", without any parameters");
        }
        log(logLevel, sb.toString());
    }

    /**
     * Logs the return type and (for non-void methods) the return value of the target method.
     *
     * @param logLevel log level to use
     * @param targetMethod target method
     * @param obj value returned by the target method
     * @param classMethodInfo "Class#Method" description of the target
     * @param isEntryMethod whether the method is a request entry method (then the request path is logged too)
     */
    private void postHandle(RecordParameters.LogLevel logLevel, Method targetMethod,
                            Object obj, String classMethodInfo, boolean isEntryMethod) {
        StringBuilder sb = new StringBuilder(64);
        sb.append(prettyPlaceholder).append("[the way out]");
        if (isEntryMethod) {
            sb.append(" request-path[").append(getRequestPath(targetMethod)).append("] ");
        }
        sb.append(classMethodInfo);
        Class<?> returnClass = targetMethod.getReturnType();
        sb.append(prettyPlaceholder).append("\treturn type -> ").append(returnClass);
        if (!VOID_STR.equals(returnClass.getName())) {
            String prettyStr = routeIgnoreOrPretty(obj);
            sb.append(prettyPlaceholder).append("\treturn result -> ").append(prettyStr);
        }
        sb.append(prettyPlaceholder);
        log(logLevel, sb.toString());
    }

    /**
     * Writes a log line at the requested level through this class's logger.
     *
     * @param logLevel level of the log line
     * @param markerValue value substituted for the single {@code {}} placeholder
     * @throws IllegalStateException if no logging method is registered for {@code logLevel}
     *                               (unsupported level, or {@link #afterPropertiesSet()} has not run)
     */
    private void log(RecordParameters.LogLevel logLevel, Object markerValue) {
        Method logMethod = METHOD_MAP.get(logLevel.name());
        if (logMethod == null) {
            throw new IllegalStateException("No logging method registered for log level '" + logLevel
                    + "', supported levels: " + METHOD_MAP.keySet());
        }
        try {
            logMethod.invoke(log, "{}", markerValue);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RuntimeException("RecordParametersAdvice$AopSupport#log occur error!", e);
        }
    }

    /**
     * Converts {@code obj} into either an "ignored type" marker or its formatted representation.
     *
     * @param obj object to convert
     * @return {@code null} if {@code obj} is {@code null}; a marker if its type is ignored; otherwise the formatted
     *         string
     */
    private String routeIgnoreOrPretty(Object obj) {
        if (obj == null) {
            return null;
        }
        //noinspection rawtypes
        Class hitIgnoreTypeClass = ignoreParamTypesSet.stream()
                .filter(x -> ClassUtils.isAssignable(x, obj.getClass()))
                .findFirst()
                .orElse(null);
        if (hitIgnoreTypeClass != null) {
            return "<hit ignore type '" + hitIgnoreTypeClass.getName() + "'>";
        }
        return doPretty(obj);
    }

    /**
     * Formats {@code obj} according to {@link #handleMode}.
     *
     * @param obj object to format
     * @return the formatted string; for JSON mode falls back to {@link String#valueOf(Object)} on failure
     * @throws UnsupportedOperationException for an unknown handle mode
     */
    String doPretty(Object obj) {
        String str;
        switch (handleMode) {
            case USE_JSON -> {
                try {
                    str = JSON.toJSONString(obj, obtainJsonIgnoreValueFilter());
                } catch (Throwable e) {
                    // best effort: serialization problems must never break the intercepted call
                    log.warn("JSON.toJSONString(obj) occur exception, e.getMessage() -> {}", e.getMessage());
                    str = String.valueOf(obj);
                }
            }
            case USE_TO_STRING -> str = String.valueOf(obj);
            case CUSTOM_SERIALIZER -> str = parameterSerializer.doSerializer(obj);
            default -> throw new UnsupportedOperationException();
        }
        return str;
    }

    /**
     * Returns this instance's JSON ignore-value filter, creating it on first use.
     */
    private JsonIgnoreValueFilter obtainJsonIgnoreValueFilter() {
        JsonIgnoreValueFilter filter = jsonIgnoreValueFilter;
        if (filter == null) {
            filter = new JsonIgnoreValueFilter(ignoreParamTypesSet.toArray(new Class[0]));
            jsonIgnoreValueFilter = filter;
        }
        return filter;
    }

    /**
     * Resolves the request path of the given method.
     *
     * @param method current method
     * @return the request path, or {@code ""} if no provider is configured or the provider fails
     */
    String getRequestPath(Method method) {
        try {
            return requestPathProvider == null ? "" : requestPathProvider.requestPath(method);
        } catch (Exception e) {
            log.warn("requestPathProvider.requestPath() occur exception, e -> {}", e.getMessage());
            return "";
        }
    }
}